import torch
from itertools import permutations

import config

EPS = 1e-8


def reorder_source(source, perms, max_snr_idx):
    """Reorder estimated sources by the best permutation of each utterance.

    Args:
        source: [B, C, T]
        perms: [C!, C], all channel permutations (long dtype)
        max_snr_idx: [B], each item is between [0, C!)
    Returns:
        reorder_source: [B, C, T], where out[b, c] = source[b, perm[b][c]]
    """
    # [B, C]: permutation whose SI-SNR is max for each utterance
    max_snr_perm = torch.index_select(perms, dim=0, index=max_snr_idx)
    # Vectorized gather along the channel dim replaces the original
    # per-utterance Python double loop:
    #   out[b, c, t] = source[b, max_snr_perm[b, c], t]
    index = max_snr_perm.unsqueeze(2).expand_as(source)
    return torch.gather(source, dim=1, index=index)


def loss(source, estimate_source, device="cuda"):
    """Negative SI-SNR (PIT) loss, assuming every utterance is full length.

    Args:
        source: [B, C, T] reference sources
        estimate_source: [B, C, T] estimated sources
        device: device on which to allocate the length tensor
    Returns:
        scalar tensor: 0 - mean(best SI-SNR)
    """
    B, _, T = estimate_source.size()
    # Every utterance is assumed unpadded, so each length is simply T.
    # Allocate directly on the target device with an explicit integer dtype
    # instead of building on CPU and transferring afterwards.
    source_lengths = torch.full((B, 1), T, dtype=torch.long, device=device)
    return cal_loss(source, estimate_source, source_lengths)


def cal_loss(source, estimate_source, source_lengths):
    """Return the PIT loss: the negated batch mean of the best SI-SNR.

    Args:
        source: [B, C, T], B is batch size
        estimate_source: [B, C, T]
        source_lengths: [B]
    """
    max_snr, _perms, _max_snr_idx = cal_si_snr_with_pit(
        source, estimate_source, source_lengths
    )
    # Maximizing SI-SNR == minimizing its negation
    return -torch.mean(max_snr)


def cal_si_snr_with_pit(source, estimate_source, source_lengths, eps=1e-8):
    """Calculate SI-SNR with PIT (permutation-invariant training).

    Args:
        source: [B, C, T], B is batch size
        estimate_source: [B, C, T]
        source_lengths: [B] or [B, 1], each item is between [0, T]
        eps: numerical-stability constant (default mirrors module-level EPS)
    Returns:
        max_snr: [B, 1], best per-channel-mean SI-SNR of each utterance
        perms: [C!, C], all channel permutations
        max_snr_idx: [B], index into perms of the best permutation
    """
    assert source.size() == estimate_source.size()
    B, C, T = source.size()
    lengths = source_lengths.view(-1, 1, 1)  # [B, 1, 1]
    # Mask padding positions along T (vectorized; accepts [B] or [B, 1]
    # lengths): position t is valid iff t < length[b].
    mask = (
        torch.arange(T, device=source.device).view(1, 1, T) < lengths
    ).type_as(source)  # [B, 1, T]
    # Out-of-place multiply so the caller's estimate_source is NOT mutated
    # (the original `estimate_source *= mask` modified it in place).
    estimate_source = estimate_source * mask

    # Step 1. Zero-mean norm over the valid (unpadded) samples
    num_samples = lengths.float()  # [B, 1, 1]
    mean_target = torch.sum(source, dim=2, keepdim=True) / num_samples
    mean_estimate = torch.sum(estimate_source, dim=2, keepdim=True) / num_samples
    # Re-mask so padded positions stay exactly zero after mean subtraction
    zero_mean_target = (source - mean_target) * mask
    zero_mean_estimate = (estimate_source - mean_estimate) * mask

    # Step 2. SI-SNR of every (estimate, target) channel pair via broadcast
    s_target = torch.unsqueeze(zero_mean_target, dim=1)  # [B, 1, C, T]
    s_estimate = torch.unsqueeze(zero_mean_estimate, dim=2)  # [B, C, 1, T]
    # s_target projection = <s', s>s / ||s||^2
    pair_wise_dot = torch.sum(
        s_estimate * s_target, dim=3, keepdim=True
    )  # [B, C, C, 1]
    s_target_energy = (
        torch.sum(s_target**2, dim=3, keepdim=True) + eps
    )  # [B, 1, C, 1]
    pair_wise_proj = pair_wise_dot * s_target / s_target_energy  # [B, C, C, T]
    # e_noise = s' - s_target
    e_noise = s_estimate - pair_wise_proj  # [B, C, C, T]
    # SI-SNR = 10 * log_10(||s_target||^2 / ||e_noise||^2)
    pair_wise_si_snr = torch.sum(pair_wise_proj**2, dim=3) / (
        torch.sum(e_noise**2, dim=3) + eps
    )
    pair_wise_si_snr = 10 * torch.log10(pair_wise_si_snr + eps)  # [B, C, C]

    # Get max_snr of each utterance over all C! channel permutations
    perms = source.new_tensor(list(permutations(range(C))), dtype=torch.long)
    # one-hot, [C!, C, C]
    index = torch.unsqueeze(perms, 2)
    perms_one_hot = source.new_zeros((*perms.size(), C)).scatter_(2, index, 1)
    # [B, C!] <- [B, C, C] einsum [C!, C, C]: SI-SNR sum of each permutation
    snr_set = torch.einsum("bij,pij->bp", [pair_wise_si_snr, perms_one_hot])
    max_snr_idx = torch.argmax(snr_set, dim=1)  # [B]
    max_snr, _ = torch.max(snr_set, dim=1, keepdim=True)  # [B, 1]
    max_snr /= C  # sum over channels -> per-channel mean
    return max_snr, perms, max_snr_idx


# NOTE(review): alternative flattened-batch SI-SNR; overlaps with cal_loss above.
def sisnr(output, target, eps=1e-8):
    """Permutation-invariant negative SI-SNR loss (flattened-batch form).

    Args:
        output: [B, C, T] estimated sources
        target: [B, C, T] reference sources
        eps: numerical-stability constant (mirrors module-level EPS);
             guards the divisions, which the original left unprotected
             (perfect reconstruction produced -inf)
    Returns:
        scalar tensor: the minimum (best-permutation) negative SI-SNR
    """
    batch_size = target.size(0)
    num_spk = target.size(1)

    # Flatten batch and speaker dims: [B*C, T]
    output1 = output.view(batch_size * num_spk, -1)
    target1 = target.view(batch_size * num_spk, -1)

    def _perm_loss(perm):
        # Reorder flattened targets so row b*C + i holds target perm[i] of
        # utterance b. zeros_like allocates on target1's own device, fixing
        # the original's reference to an undefined global `device`.
        new_perm = torch.zeros_like(target1)
        for b in range(batch_size):
            base = b * num_spk
            for i, p in enumerate(perm):
                new_perm[base + i] = target1[base + p]
        product = torch.sum(output1 * new_perm, dim=1, keepdim=True)
        l2_square = torch.sum(new_perm * new_perm, dim=1, keepdim=True)
        s_t = product * new_perm / (l2_square + eps)  # target projection
        e_n = output1 - s_t  # noise component
        ratio = torch.sum(s_t * s_t, dim=1, keepdim=True) / (
            torch.sum(e_n * e_n, dim=1, keepdim=True) + eps
        )
        return -torch.mean(10 * torch.log10(ratio + eps))

    # Permutation-invariant training: keep the best (smallest) loss
    return min(_perm_loss(perm) for perm in permutations(range(num_spk)))


def get_mask(source, source_lengths):
    """Build a binary padding mask along the time axis.

    Args:
        source: [B, C, T]
        source_lengths: [B] or [B, 1], valid length of each utterance
            (the original only accepted [B, 1] despite documenting [B])
    Returns:
        mask: [B, 1, T] in source's dtype, 1 for valid positions, 0 for padding
    """
    B, _, T = source.size()
    # Vectorized replacement for the per-utterance Python loop:
    # position t is valid iff t < length[b].
    positions = torch.arange(T, device=source.device).view(1, 1, T)
    lengths = source_lengths.view(-1, 1, 1).to(source.device)
    return (positions < lengths).type_as(source)


if __name__ == "__main__":
    # Quick CPU smoke test with random data (B=3, C=2, T=32000).
    src = torch.randn(3, 2, 32000)  # was `input`, which shadows the builtin
    est = torch.randn(3, 2, 32000)
    out = loss(src, est, device="cpu")
    print(out)
